GroupedPredictor refactoring #618

Conversation
```python
# The grouping part we always want as a DataFrame with range index
return X_group.reset_index(drop=True)


def _get_estimator(estimators, grp_values, grp_names, return_level, fallback_method):
```
The point of this function is to determine which estimator to use to predict:
- if `fallback_method="raise"`, we have to have a model for the group we are predicting
- if `fallback_method="next"`, we check recursively for the first available parent group
- if `fallback_method="global"`, we summon the global model

Returning `return_level` is a trick to know how far back we went; it is used to slice an array afterwards (more comments where that happens).
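For illustration, a minimal sketch of that lookup, assuming `estimators` is a dict keyed by tuples of group values and that the global model sits under the dummy-column key `(1,)` (both are assumptions, not necessarily the actual implementation):

```python
def _get_estimator(estimators, grp_values, grp_names, return_level, fallback_method):
    """Sketch: pick the estimator for grp_values, falling back if allowed."""
    if fallback_method == "raise":
        # The exact group must exist; a missing key raises KeyError here.
        return estimators[grp_values], return_level
    if grp_values in estimators:
        return estimators[grp_values], return_level
    if fallback_method == "next":
        # Recurse on the parent group: drop the innermost level, go one level up.
        return _get_estimator(estimators, grp_values[:-1], grp_names[:-1], return_level - 1, fallback_method)
    else:  # fallback_method == "global"
        # Assumed key for the dummy global column with fixed value 1.
        return estimators[(1,)], 1
```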
```python
if y is not None:
    y = check_array(y, ensure_2d=False)

# TODO: Validate class params?
```
Open question
```python
if self.shrinkage is not None:
    self.__set_shrinkage_function()

if is_classifier(self.estimator):
    self.classes_ = np.sort(np.unique(y))  # TODO: Must be sequential for the rest of the code to work
```
If `y` has classes in random order we fall short; should we enforce that?
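A tiny example of why sequential classes matter, assuming (hypothetically) that class labels are used directly as column indexes into the predictions array:

```python
import numpy as np

y = np.array([0, 5, 2, 5, 0])
classes_ = np.sort(np.unique(y))  # array([0, 2, 5]) -- sorted, but not sequential

preds = np.zeros((len(y), len(classes_)))
# Indexing columns by raw label would break: there is no column 5.
# preds[:, 5]  # IndexError
# Mapping labels to positions first is safe:
positions = np.searchsorted(classes_, y)  # array([0, 2, 1, 2, 0])
preds[np.arange(len(y)), positions] = 1.0
```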
```python
# TODO: __grouped_predictor_target_value__?
frame = pd.DataFrame(X).assign(__target_value__=np.array(y)).reset_index(drop=True)
```
`__target_value__` (and `__global_model__` just right after) could be "safe enough" column names, but we can even prefix with `__grouped_predictor` or `__sklego_grouped_predictor` to be extra safe that someone is not using those as column names 😂
```python
@property
def n_levels_(self):
    check_is_fitted(self, ["fitted_levels_"])
    return len(self.fitted_levels_)


def fit(self, X, y=None):
```
The fit routine does the following:
- Creates a dataframe with `X` and `y`, with the index being `range(0, len(X))`:
  - If `X` was already a dataframe, it maintains the columns and we just overwrite the index
  - If `X` was an array, the column names will coincide with the column indexes
- Does some checking on the `X`, `y` values
- Adds a dummy global model column if required (with a fixed value of 1)
- Based on the arguments (`use_global_model`, `shrinkage` and `fallback_method`), determines which levels/models need to be fitted by creating a list of lists, where the inner lists are the columns to group by on
- Trains a model for each of these levels/groups and their values, ending up with a dict of `key=group_value, value=fitted_estimator` (see the sketch after this list)
- Defines the shrinkage function and factors:
  - If shrinkage is None, instead of doing nothing, I add a factor which is zero everywhere except for the model actually trained, where it is 1.
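A condensed, hypothetical sketch of that flow (names taken from the comments above; details such as which columns stay as features are glossed over):

```python
import numpy as np
import pandas as pd
from sklearn.base import clone

def fit_all_levels(estimator, X, y, fitted_levels):
    # One frame with X, y and a fresh range index.
    frame = pd.DataFrame(X).assign(__target_value__=np.array(y)).reset_index(drop=True)
    # Dummy global column with a fixed value of 1, so the global model
    # is just another "group" with a single group value.
    frame = frame.assign(__global_model__=1)
    # One estimator per (level, group value); groupby with a list of keys
    # yields tuple-valued group keys, so the dict is keyed by tuples.
    estimators = {}
    for grp_names in fitted_levels:
        for grp_values, grp_frame in frame.groupby(grp_names):
            target = grp_frame["__target_value__"]
            features = grp_frame.drop(columns=["__target_value__", "__global_model__"])
            estimators[grp_values] = clone(estimator).fit(features, target)
    return estimators
```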
```python
return self


def __set_fit_levels(self):
```
Send help here 😁
```python
for grp_names in self.fitted_levels_:
    for grp_values, grp_frame in frame.groupby(grp_names):
```
An example here could be:

```python
groups = ["a", "b"]
fitted_levels_ = [["__global_model__"], ["__global_model__", "a"], ["__global_model__", "a", "b"]]
```

hence `grp_names` goes from outer to inner and `grp_values` are the unique values identifying the group.
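One way `__set_fit_levels` could build that list is as cumulative prefixes of the group columns (just a guess at the intent):

```python
groups = ["a", "b"]
cols = ["__global_model__"] + groups
# Cumulative prefixes, outer to inner.
fitted_levels_ = [cols[: i + 1] for i in range(len(cols))]
# [['__global_model__'], ['__global_model__', 'a'], ['__global_model__', 'a', 'b']]
```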
```python
if result.shape != expected_shape:
    raise ValueError(f"shrinkage_function({group_lengths}).shape should be {expected_shape}")
```
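For context, a hypothetical custom shrinkage function satisfying that shape contract, assuming it receives the group sizes along one path of the hierarchy and must return one factor per level:

```python
import numpy as np

def size_proportional_shrinkage(group_sizes):
    # Weight each level proportionally to how many samples its group saw.
    sizes = np.asarray(group_sizes, dtype=float)
    return sizes / sizes.sum()

factors = size_proportional_shrinkage([100, 30, 5])
assert factors.shape == (3,)  # one shrinkage factor per level, outer to inner
```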
```python
def __predict_estimators(self, X, method_name):
```
This is a bit of a headache but allows for a unified approach:
- `preds` is a 3d array of shape `(n_samples, n_levels, n_classes)`, where for regression `n_classes` can be considered to be 1. This gets populated with a prediction for each level.
- `shrinkage` is a 2d array of shape `(n_samples, n_levels)`, as the shrinkage factor is already one per level. This holds the shrinkage to use for each sample. Since we are doing predictions from the outermost to the innermost level, we overwrite when needed.
- `_get_estimator` returning the level allows us to select which model shrinkage to use for that particular prediction.
- Finally we multiply `preds` and `shrinkage` and sum over all levels (see the sketch below):
  - If shrinkage is None, the array will contain only zeros and ones, therefore it is equivalent to selecting the model to use.
  - If shrinkage is not None, it is equivalent to averaging with shrinkage factors.

Remark: `last_dim_ix` is used for the case in #579; using the estimator classes lets us index the columns/classes to which to assign the results.
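A minimal sketch of that multiply-and-sum (shapes only; values are random):

```python
import numpy as np

n_samples, n_levels, n_classes = 4, 3, 2
rng = np.random.default_rng(0)

preds = rng.random((n_samples, n_levels, n_classes))  # one prediction per level
shrinkage = np.zeros((n_samples, n_levels))
shrinkage[:, -1] = 1.0  # shrinkage=None case: select only the innermost model

# Broadcast-multiply and sum over the levels axis -> (n_samples, n_classes).
final = (preds * shrinkage[:, :, None]).sum(axis=1)
assert np.allclose(final, preds[:, -1, :])  # equivalent to selecting that model
```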
Decision Function
Decision function breaks for the following reasons:
- For binary classification, it returns a 1d array
- For multiclass, it returns a 2d array

Therefore, for the mixed case of group A with classes [0, 1, 2] and group B with classes [0, 3], it has two different output shapes and, most importantly, different meanings.

However:
- For the "normal" multiclass cases this implementation works fine, since we are filling the whole `preds`.
- For binary classification it breaks due to the fact that `preds` is initialized with `n_classes` (= 2), yet it is enough to treat this as the regression case.
- The mixed case is just painful, and actually wrong in the current API as well.
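The shape difference is easy to verify with scikit-learn directly:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.random.default_rng(0).random((6, 2))

# Binary: a 1d array of shape (n_samples,)
assert LogisticRegression().fit(X, [0, 1, 0, 1, 0, 1]).decision_function(X).shape == (6,)

# Multiclass: a 2d array of shape (n_samples, n_classes)
assert LogisticRegression().fit(X, [0, 1, 2, 0, 1, 2]).decision_function(X).shape == (6, 3)
```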
```diff
@@ -45,13 +46,15 @@ def _split_groups_and_values(
     _shape_check(X, min_value_cols)

     try:
         lgroups = as_list(groups)
```
I guess we should set `group_sizes` to also be a non-list type in the function definition? Bit of a nit this one.
Also: maybe `groups_list` instead of `lgroups`.
```python
    )
    return _get_estimator(estimators, grp_values[:-1], grp_names[:-1], return_level - 1, fallback_method)

else:  # fallback_method == "global"
```
Nit: technically the `else` isn't needed anymore because the function would've returned otherwise. Might be nicer to confirm at the start of the function that a correct fallback method is chosen.
Just noticed we check this elsewhere, so it's probably fine to not check here.
Description
This is a first attempt to refactor the `GroupedPredictor` class (personally one of my favorite features 😁) following the issue raised in #616. While working on this I noticed another set of issues with the implementation:
- For binary and multiclass classification respectively, `.decision_function(...)` will be a 1d or 2d array, which when concatenated leads to wrong behaviour.
- Say `groups=["a", "b"]` with values `(0, 0)`, `(0, 1)` and `(1, 0)`. Currently, if at prediction time we encounter `(a=0, b=2)`, for which we don't have a trained model, we fall back to the global one. However, I would argue that we should fall back to the model trained on `a=0` (toy illustration below).
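A toy illustration of that proposed behaviour, with hypothetical dict keys (the leading 1 standing for the dummy global column):

```python
estimators = {
    (1,): "global_model",
    (1, 0): "model_a0",
    (1, 1): "model_a1",
    (1, 0, 0): "model_a0_b0",
    (1, 0, 1): "model_a0_b1",
    (1, 1, 0): "model_a1_b0",
}

key = (1, 0, 2)  # unseen combination (a=0, b=2)
while key not in estimators:
    key = key[:-1]  # fall back to the parent group instead of jumping to global
assert estimators[key] == "model_a0"
```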
Type of change
Checklist:
TODOs